<html><!-- Created using the cpp_pretty_printer from the dlib C++ library.  See http://dlib.net for updates. --><head><title>dlib C++ Library - rbf_network.h</title></head><body bgcolor='white'><pre>
<font color='#009900'>// Copyright (C) 2008  Davis E. King (davis@dlib.net)
</font><font color='#009900'>// License: Boost Software License   See LICENSE.txt for the full license.
</font><font color='#0000FF'>#ifndef</font> DLIB_RBf_NETWORK_
<font color='#0000FF'>#define</font> DLIB_RBf_NETWORK_

<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='../matrix.h.html'>../matrix.h</a>"
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='rbf_network_abstract.h.html'>rbf_network_abstract.h</a>"
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='kernel.h.html'>kernel.h</a>"
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='linearly_independent_subset_finder.h.html'>linearly_independent_subset_finder.h</a>"
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='function.h.html'>function.h</a>"
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='../algs.h.html'>../algs.h</a>"

<font color='#0000FF'>namespace</font> dlib
<b>{</b>

<font color='#009900'>// ------------------------------------------------------------------------------
</font>
    <font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font>
        <font color='#0000FF'>typename</font> Kern
        <font color='#5555FF'>&gt;</font>
    <font color='#0000FF'>class</font> <b><a name='rbf_network_trainer'></a>rbf_network_trainer</b>
    <b>{</b>
        <font color='#009900'>/*!
            WHAT THIS OBJECT REPRESENTS
                This is a trainer for radial basis function (RBF) networks.  It
                follows the standard recipe (essentially the one described on
                Wikipedia), so there is nothing particularly fancy here.

                The one unusual part is how the centers are chosen: a
                linearly_independent_subset_finder is shown all the training
                samples and the dictionary it ends up with is used as the set
                of centers.  The output weights and the bias term are then
                obtained from the pseudo-inverse of the kernel matrix.
        !*/</font>

    <font color='#0000FF'>public</font>:
        <font color='#0000FF'>typedef</font> Kern kernel_type;
        <font color='#0000FF'>typedef</font> <font color='#0000FF'>typename</font> kernel_type::scalar_type scalar_type;
        <font color='#0000FF'>typedef</font> <font color='#0000FF'>typename</font> kernel_type::sample_type sample_type;
        <font color='#0000FF'>typedef</font> <font color='#0000FF'>typename</font> kernel_type::mem_manager_type mem_manager_type;
        <font color='#0000FF'>typedef</font> decision_function<font color='#5555FF'>&lt;</font>kernel_type<font color='#5555FF'>&gt;</font> trained_function_type;

        <font color='#009900'>// Creates a trainer whose kernel is default-initialized and whose
</font>        <font color='#009900'>// number of centers is 10.
</font>        <b><a name='rbf_network_trainer'></a>rbf_network_trainer</b> <font face='Lucida Console'>(</font>
        <font face='Lucida Console'>)</font> :
            num_centers<font face='Lucida Console'>(</font><font color='#979000'>10</font><font face='Lucida Console'>)</font>
        <b>{</b>
        <b>}</b>

        <font color='#009900'>// Replaces the kernel used by subsequent calls to train() with a copy of k.
</font>        <font color='#0000FF'><u>void</u></font> <b><a name='set_kernel'></a>set_kernel</b> <font face='Lucida Console'>(</font>
            <font color='#0000FF'>const</font> kernel_type<font color='#5555FF'>&amp;</font> k
        <font face='Lucida Console'>)</font>
        <b>{</b>
            kernel <font color='#5555FF'>=</font> k;
        <b>}</b>

        <font color='#009900'>// Returns the kernel currently stored in this trainer.
</font>        <font color='#0000FF'>const</font> kernel_type<font color='#5555FF'>&amp;</font> <b><a name='get_kernel'></a>get_kernel</b> <font face='Lucida Console'>(</font>
        <font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>
        <b>{</b>
            <font color='#0000FF'>return</font> kernel;
        <b>}</b>

        <font color='#009900'>// Sets the number of centers requested from the center selection step.
</font>        <font color='#009900'>// No validation is performed here.
</font>        <font color='#0000FF'><u>void</u></font> <b><a name='set_num_centers'></a>set_num_centers</b> <font face='Lucida Console'>(</font>
            <font color='#0000FF'>const</font> <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> num
        <font face='Lucida Console'>)</font>
        <b>{</b>
            num_centers <font color='#5555FF'>=</font> num;
        <b>}</b>

        <font color='#009900'>// Returns the number of centers most recently set (10 by default).
</font>        <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> <b><a name='get_num_centers'></a>get_num_centers</b> <font face='Lucida Console'>(</font>
        <font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>
        <b>{</b>
            <font color='#0000FF'>return</font> num_centers;
        <b>}</b>

        <font color='#009900'>// Trains an RBF network on the samples x and their target values y and
</font>        <font color='#009900'>// returns it as a decision_function.  Both arguments are passed through
</font>        <font color='#009900'>// mat() and must satisfy is_learning_problem(x,y), which do_train()
</font>        <font color='#009900'>// checks with DLIB_ASSERT.
</font>        <font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font>
            <font color='#0000FF'>typename</font> in_sample_vector_type,
            <font color='#0000FF'>typename</font> in_scalar_vector_type
            <font color='#5555FF'>&gt;</font>
        <font color='#0000FF'>const</font> decision_function<font color='#5555FF'>&lt;</font>kernel_type<font color='#5555FF'>&gt;</font> <b><a name='train'></a>train</b> <font face='Lucida Console'>(</font>
            <font color='#0000FF'>const</font> in_sample_vector_type<font color='#5555FF'>&amp;</font> x,
            <font color='#0000FF'>const</font> in_scalar_vector_type<font color='#5555FF'>&amp;</font> y
        <font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>
        <b>{</b>
            <font color='#0000FF'>return</font> <font color='#BB00BB'>do_train</font><font face='Lucida Console'>(</font><font color='#BB00BB'>mat</font><font face='Lucida Console'>(</font>x<font face='Lucida Console'>)</font>, <font color='#BB00BB'>mat</font><font face='Lucida Console'>(</font>y<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
        <b>}</b>

        <font color='#009900'>// Exchanges the kernel and the number of centers with those of item.
</font>        <font color='#0000FF'><u>void</u></font> <b><a name='swap'></a>swap</b> <font face='Lucida Console'>(</font>
            rbf_network_trainer<font color='#5555FF'>&amp;</font> item
        <font face='Lucida Console'>)</font>
        <b>{</b>
            <font color='#BB00BB'>exchange</font><font face='Lucida Console'>(</font>kernel, item.kernel<font face='Lucida Console'>)</font>;
            <font color='#BB00BB'>exchange</font><font face='Lucida Console'>(</font>num_centers, item.num_centers<font face='Lucida Console'>)</font>;
        <b>}</b>

    <font color='#0000FF'>private</font>:

    <font color='#009900'>// ------------------------------------------------------------------------------------
</font>
        <font color='#009900'>// Does the actual training for train().  The steps are:
</font>        <font color='#009900'>//   1. pick the centers with a linearly_independent_subset_finder
</font>        <font color='#009900'>//   2. build the matrix K of kernel values between every sample and every
</font>        <font color='#009900'>//      center, plus one extra column of ones for the bias term
</font>        <font color='#009900'>//   3. compute the weights as pinv(K)*y
</font>        <font color='#009900'>//   4. pack the weights, bias, kernel and centers into a decision_function
</font>        <font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font>
            <font color='#0000FF'>typename</font> in_sample_vector_type,
            <font color='#0000FF'>typename</font> in_scalar_vector_type
            <font color='#5555FF'>&gt;</font>
        <font color='#0000FF'>const</font> decision_function<font color='#5555FF'>&lt;</font>kernel_type<font color='#5555FF'>&gt;</font> <b><a name='do_train'></a>do_train</b> <font face='Lucida Console'>(</font>
            <font color='#0000FF'>const</font> in_sample_vector_type<font color='#5555FF'>&amp;</font> x,
            <font color='#0000FF'>const</font> in_scalar_vector_type<font color='#5555FF'>&amp;</font> y
        <font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>
        <b>{</b>
            <font color='#0000FF'>typedef</font> <font color='#0000FF'>typename</font> decision_function<font color='#5555FF'>&lt;</font>kernel_type<font color='#5555FF'>&gt;</font>::scalar_vector_type scalar_vector_type;

            <font color='#009900'>// make sure requires clause is not broken
</font>            <font color='#BB00BB'>DLIB_ASSERT</font><font face='Lucida Console'>(</font><font color='#BB00BB'>is_learning_problem</font><font face='Lucida Console'>(</font>x,y<font face='Lucida Console'>)</font>,
                "<font color='#CC0000'>\tdecision_function rbf_network_trainer::train(x,y)</font>"
                <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\n\t invalid inputs were given to this function</font>"
                <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\n\t x.nr(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> x.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>
                <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\n\t y.nr(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> y.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>
                <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\n\t x.nc(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> x.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>
                <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\n\t y.nc(): </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> y.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>
                <font face='Lucida Console'>)</font>;

            <font color='#009900'>// use the linearly_independent_subset_finder object to select the centers.  So here
</font>            <font color='#009900'>// we show it all the data samples so it can find the best centers.  The num_centers
</font>            <font color='#009900'>// used on the next line is the data member, i.e. the number the user asked for.
</font>            linearly_independent_subset_finder<font color='#5555FF'>&lt;</font>kernel_type<font color='#5555FF'>&gt;</font> <font color='#BB00BB'>lisf</font><font face='Lucida Console'>(</font>kernel, num_centers<font face='Lucida Console'>)</font>;
            <font color='#BB00BB'>fill_lisf</font><font face='Lucida Console'>(</font>lisf, x<font face='Lucida Console'>)</font>;

            <font color='#009900'>// From here on we work with the number of centers lisf actually holds.  This is kept
</font>            <font color='#009900'>// in its own variable rather than reusing the name num_centers so that it does not
</font>            <font color='#009900'>// shadow the data member of the same name.
</font>            <font color='#0000FF'>const</font> <font color='#0000FF'><u>long</u></font> num_selected_centers <font color='#5555FF'>=</font> lisf.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;

            <font color='#009900'>// fill the K matrix with the output of the kernel for all the center and sample point pairs
</font>            matrix<font color='#5555FF'>&lt;</font>scalar_type,<font color='#979000'>0</font>,<font color='#979000'>0</font>,mem_manager_type<font color='#5555FF'>&gt;</font> <font color='#BB00BB'>K</font><font face='Lucida Console'>(</font>x.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, num_selected_centers<font color='#5555FF'>+</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>;
            <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font> r <font color='#5555FF'>=</font> <font color='#979000'>0</font>; r <font color='#5555FF'>&lt;</font> x.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>r<font face='Lucida Console'>)</font>
            <b>{</b>
                <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font> c <font color='#5555FF'>=</font> <font color='#979000'>0</font>; c <font color='#5555FF'>&lt;</font> num_selected_centers; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>c<font face='Lucida Console'>)</font>
                <b>{</b>
                    <font color='#BB00BB'>K</font><font face='Lucida Console'>(</font>r,c<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> <font color='#BB00BB'>kernel</font><font face='Lucida Console'>(</font><font color='#BB00BB'>x</font><font face='Lucida Console'>(</font>r<font face='Lucida Console'>)</font>, lisf[c]<font face='Lucida Console'>)</font>;
                <b>}</b>
                <font color='#009900'>// This last column of the K matrix takes care of the bias term
</font>                <font color='#BB00BB'>K</font><font face='Lucida Console'>(</font>r,num_selected_centers<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> <font color='#979000'>1</font>;
            <b>}</b>

            <font color='#009900'>// compute the best weights by using the pseudo-inverse
</font>            scalar_vector_type <font color='#BB00BB'>weights</font><font face='Lucida Console'>(</font><font color='#BB00BB'>pinv</font><font face='Lucida Console'>(</font>K<font face='Lucida Console'>)</font><font color='#5555FF'>*</font>y<font face='Lucida Console'>)</font>;

            <font color='#009900'>// now put everything into a decision_function object and return it.  The last
</font>            <font color='#009900'>// element of weights belongs to the bias column, so it is removed from the weight
</font>            <font color='#009900'>// vector and handed over separately, negated.
</font>            <font color='#0000FF'>return</font> decision_function<font color='#5555FF'>&lt;</font>kernel_type<font color='#5555FF'>&gt;</font> <font face='Lucida Console'>(</font><font color='#BB00BB'>remove_row</font><font face='Lucida Console'>(</font>weights,num_selected_centers<font face='Lucida Console'>)</font>,
                                                   <font color='#5555FF'>-</font><font color='#BB00BB'>weights</font><font face='Lucida Console'>(</font>num_selected_centers<font face='Lucida Console'>)</font>,
                                                   kernel,
                                                   lisf.<font color='#BB00BB'>get_dictionary</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;

        <b>}</b>

        <font color='#009900'>// kernel used both to select the centers and to fill the K matrix
</font>        kernel_type kernel;
        <font color='#009900'>// number of centers requested by the user; passed to the
</font>        <font color='#009900'>// linearly_independent_subset_finder constructor in do_train()
</font>        <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> num_centers;

    <b>}</b>; <font color='#009900'>// end of class rbf_network_trainer 
</font>
<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
    <font color='#009900'>// Global swap for rbf_network_trainer objects.  It simply forwards to the
</font>    <font color='#009900'>// member function a.swap(b).
</font>    <font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#0000FF'>typename</font> kernel_type<font color='#5555FF'>&gt;</font>
    <font color='#0000FF'><u>void</u></font> <b><a name='swap'></a>swap</b> <font face='Lucida Console'>(</font>
        rbf_network_trainer<font color='#5555FF'>&lt;</font>kernel_type<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> a,
        rbf_network_trainer<font color='#5555FF'>&lt;</font>kernel_type<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> b
    <font face='Lucida Console'>)</font>
    <b>{</b>
        a.<font color='#BB00BB'>swap</font><font face='Lucida Console'>(</font>b<font face='Lucida Console'>)</font>;
    <b>}</b>

<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
<b>}</b>

<font color='#0000FF'>#endif</font> <font color='#009900'>// DLIB_RBf_NETWORK_
</font>

</pre></body></html>